#!/usr/bin/env bash
# Fine-tune Stable Diffusion v1-4 on the Pokemon BLIP captions dataset.
# Usage: bash <this-script> <device_id> <output_dir_name>
set -euo pipefail

# Pretrained model path (local copy of CompVis/stable-diffusion-v1-4).
export MODEL_NAME="/data_nvme0n1p2/hemuhui/checkpoint/pretrained/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

# Get params; abort with a usage hint if either positional arg is missing.
DEVICE_ID=${1:?usage: $0 <device_id> <output_dir_name>}
output_dir_name=${2:?usage: $0 <device_id> <output_dir_name>}

output_dir="../checkpoint/sd-pokemon-model/${output_dir_name}"

# Deliberately NOT 'mkdir -p': creating the run directory must fail if it
# already exists, so an earlier run's checkpoints are never overwritten.
if ! mkdir -- "${output_dir}"; then
  echo "error: cannot create output dir '${output_dir}' (already exists?)" >&2
  exit 1
fi

# Launch training, mirroring console output into the run directory.
# All expansions are quoted so paths/values with spaces cannot word-split.
python examples/text_to_image/train_text_to_image.py \
  --device_id="${DEVICE_ID}" \
  --pretrained_model_name_or_path="${MODEL_NAME}" \
  --dataset_name="${dataset_name}" \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --checkpointing_steps=3000 \
  --output_dir="${output_dir}" | tee "${output_dir}/train.log"
# 'tee' would otherwise decide the pipeline's status; propagate python's
# exit code so callers (and CI) see training failures as non-zero.
exit "${PIPESTATUS[0]}"


# Optional extra flags (currently disabled) — append to the python command above to enable:
#  --use_ema \
#  --mixed_precision="fp16" \
#  --use_jit